おワインの品質を AmazonSageMaker のXGBoostで予測してみた
はじめに
おはようございます、もきゅりんです。
最近は個人的な取り組みの一環として、機械学習の学習に取り組んでいます。
前回 は、LinearLearnerで住宅価格推定をしました。
今回は、XGBoost を使ってみたく、Wine Quality Data Set
を使っておワインの品質を分類します。
(sklearn.datasets
のワインデータサイズは178しかないため、4898あるこちらのおワインデータセットを選びました)
XGBoostは、(書籍やネットで詳しい説明を読んでいただければと思いますが)より単純で弱い決定木モデルのセットから推定のアンサンブルを組み合わせることで、ターゲット変数の正確な予測を試行します。
現在はXGBoostに近いLightGBMモデルのほうが人気のようですが、機械学習のコンペティションにおいてよく使われていたようです。
なお、自分は専門的なデータサイエンティストでも何でもないので、無駄、非効率な作業を行っているかもしれない点、ご了承下さい。
前提
- データを格納するS3バケットがあること
- Jupyterノートブックが作成されていること
- IAM権限を設定・更新できること
こちらの詳細の作業については下記を参照下さい。
はじめてのSageMaker みんな大好きアイリスデータを使って組み込みアルゴリズムで分類してみる
やること
ワインの品質を、11つの特徴(アルコール度数とか水素イオン指数とか)で分類しようぜ! が趣旨です。
- データのロード、データを探索、処理、S3アップロード
- モデルでの学習
- モデルのデプロイ
- モデルの検証
- 後片付け
1. データのロード、データを探索、処理、S3アップロード
データのロードとデータを軽く探索
import pandas as pd import seaborn as sns import numpy as np import matplotlib.pyplot as plt
import urllib.request urllib.request.urlretrieve("http://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv", 'red_wine') df = pd.read_csv('./red_wine',sep=';',)
データ内容を確認します。
df.info()
Null
はなさそうです。
品質のデータが偏っているという情報が記載されているので、確認します。
df['quality'].value_counts()
確かに、5,6の品質が多く、非常に悪い、非常に良いというデータは少ない ですね。
じゃ仮に、すべてのワインを最も一番多い品質である5に予測したらどんなものか見てみます。
df['quality'].value_counts()/len(df)
0.425891
なので、最低限今回の分類はこの予測率を超えないとダメですね。
特徴量がどのように影響しているかをヒートマップで大雑把に確認しておきます。
(相関係数はあくまで線形関係を捉えるものであり、必ずしも因果関係があるわけではなく、見かけ上の相関もあるため注意が必要です。)
correlation_matrix = df.corr().round(2) plt.figure(figsize=(12, 8)) sns.heatmap(data=correlation_matrix, annot=True,cmap="YlGnBu")
うーん、 alcohol
と sulphates
に微弱の正の相関が、volatile acidity
に微弱の負の相関がある、ような感じですが、ほとんど無関係といってもよさような値だと思います。
とはいえ、足掛かりにはなりそうな情報です。
データの分割
第一列に教師ラベルが必要なので、データ列を並べ替えます。
df = df[df.columns[::-1]]
訓練データ、テストデータ、検証データの分割を行います。
それぞれ全データの60%、25%、15%としています。
train_set, valid_set, test_set = np.split(df.sample(frac=1), [int(.6*len(df)), int(.85*len(df))])
DataFrame
を ndarray
に変換します。
train_set = train_set.values valid_set = valid_set.values test_set = test_set.values
環境変数とロールを確認
%%time import os import boto3 import re import numpy as np from sagemaker import get_execution_role role = get_execution_role() region = boto3.Session().region_name bucket='YOUR_BUCKET_NAME' prefix = 'sagemaker/xgboost-redwine' # customize to your bucket where you have stored the data bucket_path = 'https://s3-{}.amazonaws.com/{}'.format(region,bucket)
データのアップロード
それぞれのデータをS3バケットにアップロードします。
アップロードする関数です。
def convert_data(feature_dim): for data_partition_name, data_partition in data_partitions: print('{}: {} {}'.format(data_partition_name, data_partition[0].shape, data_partition[1].shape)) labels = [t.tolist() for t in data_partition[:,0]] # ラベルの抽出 features = [t.tolist() for t in data_partition[:,1:feature_dim + 1]] # 特徴量の抽出 if data_partition_name != 'test': examples = np.insert(features, 0, labels, axis=1) else: examples = features np.savetxt('data.csv', examples, delimiter=',') key = "{}/{}/data".format(prefix,data_partition_name) url = 's3://{}/{}'.format(bucket, key) boto3.Session().resource('s3').Bucket(bucket).Object(key).upload_file('data.csv') print('Done writing to {}'.format(url)) def get_data(data): return 's3://{}/{}/{}'.format(bucket, prefix, data) def set_channel(data,content_type): return sagemaker.session.s3_input(data, content_type=content_type)
data_partitions = [('train', train_set), ('validation', valid_set), ('test', test_set)] # 特徴量が引数 convert_data(11)
2. モデルでの学習
XGBoost
のモデルを作成します。
import sagemaker from sagemaker.amazon.amazon_estimator import get_image_uri container = get_image_uri(boto3.Session().region_name, 'xgboost','0.90-1')
前ステップでアップロードしたS3から訓練データと検証データをダウンロードし、トレーニングの出力を保存する場所を設定します。
#Load the dataset from S3 train_data = get_data('train') validation_data = get_data('validation') s3_output_location = 's3://{}/{}/{}'.format(bucket, prefix, 'xgboost_model_sdk')
モデルのコンテナを取得します。
xgb_model = sagemaker.estimator.Estimator(container, role, train_instance_count=1, train_instance_type='ml.m4.xlarge', train_volume_size = 5, output_path=s3_output_location, sagemaker_session=sagemaker.Session())
モデルのハイパーパラメータを設定します。
特徴量は 11
です。
ゴリゴリ回して early-stopping
で制御します。
基本的な戦術は、variance
の高いモデルにする。
それから、過学習を抑制していく、にしています。
ハイパーパラメータに関する詳細はドキュメントをご確認ください。
xgb_model.set_hyperparameters( objective = 'multi:softmax', max_depth = 9, num_class = 9, eta = .1, eval_metric = 'mlogloss', gamma = .1, min_child_weight = 1, subsample = .7, colsample_bytree = .7, num_round = 1000, early_stopping_rounds = 10 )
モデルが利用するデータチャネルを train
と validation
で作ります。
train_channel = set_channel(train_data, 'text/csv') valid_channel = set_channel(validation_data, 'text/csv') data_channels = {'train': train_channel, 'validation': valid_channel}
訓練を開始します。
xgb_model.fit(inputs=data_channels, logs=True)
出力されたログを見ているだけで、あまり良くなさそうな結果とは思いますが、いちおうグラフを見ておきます。
訓練データと検証データで大きく差があるのが分かります。
今回はXGBoostを使うのが目的なので、これで良しとしましょう。
3. モデルのデプロイ
訓練されたモデルをエンドポイントにデプロイすることもできますが、特にエンドポイントは必要ないのでバッチ変換します。
11つの特徴量のみを与えられているテストデータを元に、ワインの品質を分類します。
batch_input = 's3://{}/{}/test/data'.format(bucket, prefix) batch_output = 's3://{}/{}/batch-inference'.format(bucket, prefix) transformer = linear_model.transformer(instance_count=1, instance_type='ml.c4.xlarge', output_path=batch_output) transformer.transform(data=batch_input, data_type='S3Prefix', content_type='text/csv', split_type='Line') transformer.wait()
4. モデルの検証
モデルから予測された予測を答え合わせします。
テストデータから予測したファイルをダウンロードします。
boto3.resource('s3').Bucket(bucket).download_file(prefix + '/batch-inference/data.out', 'test_wine_results')
y_test = test_set[:,0] y_pred = np.loadtxt('./test_wine_results')
正解率です。
from sklearn.metrics import accuracy_score acc = accuracy_score(y_test, y_pred) acc
0.675
なので、最低限のしきい値(0.425891
)は超えています。
特徴量をいじったり、ハイパーパラメータを調整することで改善するかと思います。
なお5,6に偏っているデータのため、混同行列 *1で確認してみます。
from sklearn.metrics import classification_report print(classification_report(y_test, y_pred))
品質5,6に対する分類予測は、比較的悪くはないのですが、サイズの少ない3,4,8は全然うまくいっていないのが分かります。
ちなみに、ネット上ではこれらのワインの品質を「優・可・不」のような3クラスにまとめて分類予測するものが多かったです。
その方法だと、90%以上の予測率が達成されていました。
5. 後片付け
不用なインスタンスは削除しましょう。
S3バケットも忘れずサヨナラしましょう。
以上です。
引き続き学習を進めていきます。
どなたかのお役に立てば幸いです。
参考:
脚注
- 混同行列については別途書籍またはネットで調べてみて下さい。 ↩